from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import tensorflow as tf

from onnx.backend.base import BackendRep, namedtupledict


class TensorflowRep(BackendRep):
  """Backend representation that runs or exports a TensorFlow graph.

  The representation holds four pieces of state:

  * ``graph``: the graph object; it must provide ``as_default()`` and
    ``as_graph_def()`` (presumably a ``tf.Graph`` -- confirm with the code
    that constructs this object).
  * ``inputs``: ordered names of the graph inputs that callers must feed.
  * ``outputs``: ordered names of the graph outputs that ``run`` fetches.
  * ``tensor_dict``: mapping from every name in ``inputs``/``outputs`` to the
    corresponding graph tensor.
  """

  def __init__(self, graph=None, inputs=None, outputs=None, tensor_dict=None):
    """Create a representation; falsy arguments fall back to empty defaults.

    :param graph: Graph to run/export (``None`` until assigned).
    :param inputs: List of input names; ``None`` becomes ``[]``.
    :param outputs: List of output names; ``None`` becomes ``[]``.
    :param tensor_dict: Name-to-tensor mapping; ``None`` becomes ``{}``.
    """
    super(TensorflowRep, self).__init__()
    self._graph = graph
    self._inputs = inputs or []
    self._outputs = outputs or []
    self._tensor_dict = tensor_dict or {}

  @property
  def graph(self):
    """The graph this representation runs and exports."""
    return self._graph

  @graph.setter
  def graph(self, graph):
    self._graph = graph

  @property
  def inputs(self):
    """Ordered list of input names expected by ``run``."""
    return self._inputs

  @inputs.setter
  def inputs(self, inputs):
    self._inputs = inputs

  @property
  def outputs(self):
    """Ordered list of output names fetched by ``run``."""
    return self._outputs

  @outputs.setter
  def outputs(self, outputs):
    self._outputs = outputs

  @property
  def tensor_dict(self):
    """Mapping from input/output names to graph tensors."""
    return self._tensor_dict

  @tensor_dict.setter
  def tensor_dict(self, tensor_dict):
    self._tensor_dict = tensor_dict

  def run(self, inputs, **kwargs):
    """ Run TensorflowRep.

    A new session is opened for every call and
    ``tf.global_variables_initializer()`` is executed before the outputs
    are fetched, so variable values are re-initialized on each call.

    :param inputs: Given inputs. One of:
      * a dict keyed by input name (it must contain every name in
        ``self.inputs``; extra keys are ignored),
      * a list or tuple of values in the same order as ``self.inputs``,
      * any other object, which is fed to the first (single) input.
    :param kwargs: Other args, forwarded to ``BackendRep.run``.
    :return: Outputs, as an ``Outputs`` object built by ``namedtupledict``
      with one field per name in ``self.outputs``.
    :raises RuntimeError: If a list/tuple is given whose length differs from
      the number of expected inputs.
    """
    super(TensorflowRep, self).run(inputs, **kwargs)

    # TODO: handle name scope if necessary
    with self.graph.as_default():
      with tf.Session() as sess:
        # Normalize every accepted input form to a {name: value} mapping.
        if isinstance(inputs, dict):
          named_inputs = inputs
        elif isinstance(inputs, (list, tuple)):
          if len(self.inputs) != len(inputs):
            raise RuntimeError('Expected {} values for uninitialized '
                               'graph inputs ({}), but got {}.'.format(
                                   len(self.inputs), ', '.join(self.inputs),
                                   len(inputs)))
          named_inputs = dict(zip(self.inputs, inputs))
        else:
          # single input
          named_inputs = {self.inputs[0]: inputs}

        # Session.run is fed by tensor, not by name, so translate the keys.
        feed_dict = {
            self.tensor_dict[key]: named_inputs[key] for key in self.inputs
        }

        sess.run(tf.global_variables_initializer())
        fetches = [self.tensor_dict[output] for output in self.outputs]

        output_values = sess.run(fetches, feed_dict=feed_dict)
        return namedtupledict('Outputs', self.outputs)(*output_values)

  def export_graph(self, path):
    """Export backend representation to a Tensorflow proto file.

    This function obtains the graph proto corresponding to the ONNX
    model associated with the backend representation and serializes
    to a protobuf file.

    Before serialization, each node whose name equals an output tensor's
    name with ``':0'`` removed is renamed to the corresponding entry of
    ``self.outputs``.

    :param path: The path to the output TF protobuf file.

    :returns: none.
    """
    graph_proto = self.graph.as_graph_def()

    # rename the output nodes
    # NOTE(review): only ``node.name`` is rewritten; ``input`` references held
    # by other nodes are left as-is -- confirm output nodes are never consumed
    # by another node under their old name.
    meaningful_names = {
        self.tensor_dict[output_name].name.replace(':0', ''): output_name
        for output_name in self.outputs
    }
    for node in graph_proto.node:
      if node.name in meaningful_names:
        node.name = meaningful_names[node.name]

    # Serialize first so a serialization failure does not truncate ``path``;
    # the ``with`` block guarantees the handle is closed even if write fails.
    serialized = graph_proto.SerializeToString()
    with open(path, "wb") as proto_file:
      proto_file.write(serialized)
